{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import gym\n",
    "#import math\n",
    "import numpy as np\n",
    "from collections import deque\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten\n",
    "from tensorflow.keras.optimizers import Adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upper bound on the number of training episodes iterated by DQN.train().\n",
    "EPOCHS = 1000\n",
    "# DQN.train() returns early once the mean score over (up to) the last 100 episodes reaches this value.\n",
    "THRESHOLD = 10\n",
    "# When True, DQN.__init__ wraps the environment in gym.wrappers.Monitor, writing under '../data/<env id>'.\n",
    "MONITOR = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DQN():\n",
    "    \"\"\"Deep Q-Network agent operating on stacked, preprocessed image frames.\n",
    "\n",
    "    `self.model` is the online Q-network that is trained and used to act;\n",
    "    `self.model_target` is a copy that `train` re-synchronises every 10\n",
    "    episodes and that `replay` uses to compute the bootstrap targets.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env_string, batch_size=64, IM_SIZE=84, m=4):\n",
    "        \"\"\"Create the environment, the replay memory and both Q-networks.\n",
    "\n",
    "        Args:\n",
    "            env_string: Gym environment id passed to `gym.make`.\n",
    "            batch_size: Number of transitions sampled by each `replay` call.\n",
    "            IM_SIZE: Side length of the square frame fed to the network.\n",
    "            m: Number of frames stacked along the channel axis.\n",
    "        \"\"\"\n",
    "        self.memory = deque(maxlen=5000)\n",
    "        self.env = gym.make(env_string)\n",
    "        action_size = self.env.action_space.n\n",
    "        self.batch_size = batch_size\n",
    "        self.gamma = 1.0\n",
    "        self.epsilon = 1.0\n",
    "        self.epsilon_min = 0.01\n",
    "        self.epsilon_decay = 0.995\n",
    "        self.IM_SIZE = IM_SIZE\n",
    "        self.m = m\n",
    "\n",
    "        alpha = 0.01\n",
    "        alpha_decay = 0.01\n",
    "        if MONITOR:\n",
    "            self.env = gym.wrappers.Monitor(self.env, '../data/'+env_string, force=True)\n",
    "\n",
    "        # Online network: three conv layers followed by two dense layers,\n",
    "        # ending in one linear output per action.\n",
    "        self.model = Sequential()\n",
    "        self.model.add(Conv2D(32, 8, (4, 4), activation='relu', padding='valid',\n",
    "                              input_shape=(IM_SIZE, IM_SIZE, m)))\n",
    "        self.model.add(Conv2D(64, 4, (2, 2), activation='relu', padding='valid'))\n",
    "        self.model.add(MaxPooling2D())\n",
    "        self.model.add(Conv2D(64, 3, (1, 1), activation='relu', padding='valid'))\n",
    "        self.model.add(MaxPooling2D())\n",
    "        self.model.add(Flatten())\n",
    "        self.model.add(Dense(256, activation='elu'))\n",
    "        self.model.add(Dense(action_size, activation='linear'))\n",
    "        # `learning_rate` is the documented argument name; `lr` is a deprecated alias.\n",
    "        self.model.compile(loss='mse', optimizer=Adam(learning_rate=alpha, decay=alpha_decay))\n",
    "\n",
    "        # clone_model() only copies the architecture and re-initialises the\n",
    "        # weights, so copy them explicitly to make the target start as a true copy.\n",
    "        self.model_target = tf.keras.models.clone_model(self.model)\n",
    "        self.model_target.set_weights(self.model.get_weights())\n",
    "\n",
    "    def remember(self, state, action, reward, next_state, done):\n",
    "        \"\"\"Append one transition to the bounded replay memory.\"\"\"\n",
    "        self.memory.append((state, action, reward, next_state, done))\n",
    "\n",
    "    def choose_action(self, state, epsilon):\n",
    "        \"\"\"Return a random action with probability `epsilon`, else the greedy one.\n",
    "\n",
    "        Args:\n",
    "            state: Network input as produced by `combine_images`.\n",
    "            epsilon: Exploration probability.\n",
    "        \"\"\"\n",
    "        if np.random.random() <= epsilon:\n",
    "            return self.env.action_space.sample()\n",
    "        return np.argmax(self.model.predict(state))\n",
    "\n",
    "    def preprocess_state(self, img):\n",
    "        \"\"\"Crop, grayscale and resize a raw frame.\n",
    "\n",
    "        Args:\n",
    "            img: Raw observation; assumed to be an H x W x 3 RGB image, since\n",
    "                it is passed to `tf.image.rgb_to_grayscale` (confirm for other envs).\n",
    "\n",
    "        Returns:\n",
    "            A float32 tensor of shape (IM_SIZE, IM_SIZE).\n",
    "        \"\"\"\n",
    "        img_temp = img[31:195]  # Choose the important area of the image\n",
    "        img_temp = tf.image.rgb_to_grayscale(img_temp)\n",
    "        img_temp = tf.image.resize(img_temp, [self.IM_SIZE, self.IM_SIZE],\n",
    "                                   method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
    "        img_temp = tf.cast(img_temp, tf.float32)\n",
    "        return img_temp[:, :, 0]\n",
    "\n",
    "    def combine_images(self, img1, img2):\n",
    "        \"\"\"Build the network input from the newest frame and the previous state.\n",
    "\n",
    "        Args:\n",
    "            img1: Newest preprocessed frame, shape (H, W).\n",
    "            img2: Previous state of shape (1, H, W, m) or (H, W, m). Anything\n",
    "                else (e.g. a single frame at the start of an episode) is\n",
    "                treated as \"no history yet\".\n",
    "\n",
    "        Returns:\n",
    "            float32 NumPy array of shape (1, H, W, m). With a history, the\n",
    "            oldest channel is dropped and `img1` becomes the last channel;\n",
    "            without one, `img1` is repeated `m` times.\n",
    "        \"\"\"\n",
    "        frame = np.asarray(img1, dtype=np.float32)\n",
    "        history = np.asarray(img2, dtype=np.float32)\n",
    "        if history.ndim == 4 and history.shape[0] == 1:\n",
    "            history = history[0]  # strip the batch axis\n",
    "\n",
    "        has_history = (history.ndim == 3\n",
    "                       and history.shape[:2] == frame.shape\n",
    "                       and history.shape[2] == self.m)\n",
    "        if has_history:\n",
    "            # Frames live on the last axis, so slide the window along axis 2.\n",
    "            stacked = np.concatenate([history[:, :, 1:], frame[:, :, np.newaxis]], axis=2)\n",
    "        else:\n",
    "            stacked = np.stack([frame] * self.m, axis=2)\n",
    "        return np.expand_dims(stacked, 0)\n",
    "\n",
    "    def replay(self, batch_size):\n",
    "        \"\"\"Fit the online network on a random minibatch from the replay memory.\n",
    "\n",
    "        The regression target equals the online network's own prediction for\n",
    "        every action except the one taken, which is set to `reward` for\n",
    "        terminal transitions and to `reward + gamma * max_a Q_target(next_state, a)`\n",
    "        otherwise.\n",
    "\n",
    "        Args:\n",
    "            batch_size: Maximum number of transitions to sample.\n",
    "        \"\"\"\n",
    "        minibatch = random.sample(self.memory, min(len(self.memory), batch_size))\n",
    "        if not minibatch:\n",
    "            return  # nothing stored yet; fitting on an empty batch would fail\n",
    "\n",
    "        states = np.concatenate([np.asarray(t[0]) for t in minibatch], axis=0)\n",
    "        next_states = np.concatenate([np.asarray(t[3]) for t in minibatch], axis=0)\n",
    "        actions = np.array([t[1] for t in minibatch])\n",
    "        rewards = np.array([t[2] for t in minibatch], dtype=np.float32)\n",
    "        dones = np.array([t[4] for t in minibatch], dtype=bool)\n",
    "\n",
    "        # One batched forward pass per network instead of two per sample.\n",
    "        y_batch = np.array(self.model.predict(states))\n",
    "        next_q = np.max(self.model_target.predict(next_states), axis=1)\n",
    "        y_batch[np.arange(len(minibatch)), actions] = np.where(\n",
    "            dones, rewards, rewards + self.gamma * next_q)\n",
    "\n",
    "        self.model.fit(states, y_batch, batch_size=len(states), verbose=0)\n",
    "\n",
    "    def train(self):\n",
    "        \"\"\"Run up to EPOCHS episodes, with one replay step after each episode.\n",
    "\n",
    "        Epsilon is decayed after every environment step, and the target\n",
    "        network is synchronised every 10 episodes.\n",
    "\n",
    "        Returns:\n",
    "            List with, per finished episode, the mean score over (up to) the\n",
    "            last 100 episodes. Training stops early once that mean reaches\n",
    "            THRESHOLD.\n",
    "        \"\"\"\n",
    "        scores = deque(maxlen=100)\n",
    "        avg_scores = []\n",
    "\n",
    "        for e in range(EPOCHS):\n",
    "            frame = self.preprocess_state(self.env.reset())\n",
    "            state = self.combine_images(frame, frame)\n",
    "            done = False\n",
    "            episode_reward = 0\n",
    "            while not done:\n",
    "                action = self.choose_action(state, self.epsilon)\n",
    "                next_frame, reward, done, _ = self.env.step(action)\n",
    "                next_state = self.combine_images(self.preprocess_state(next_frame), state)\n",
    "                self.remember(state, action, reward, next_state, done)\n",
    "                state = next_state\n",
    "                self.epsilon = max(self.epsilon_min, self.epsilon_decay*self.epsilon) # decrease epsilon\n",
    "                episode_reward += reward\n",
    "\n",
    "            scores.append(episode_reward)\n",
    "            mean_score = np.mean(scores)\n",
    "            avg_scores.append(mean_score)\n",
    "            if mean_score >= THRESHOLD:\n",
    "                print('Solved after {} trials ✔'.format(e))\n",
    "                return avg_scores\n",
    "            if e % 10 == 0:\n",
    "                print('[Episode {}] - Average Score: {}.'.format(e, mean_score))\n",
    "                self.model_target.set_weights(self.model.get_weights())\n",
    "\n",
    "            self.replay(self.batch_size)\n",
    "\n",
    "        # Report EPOCHS rather than the loop variable, which is one less than\n",
    "        # the episode count and undefined when EPOCHS is 0.\n",
    "        print('Did not solve after {} episodes 😞'.format(EPOCHS))\n",
    "        return avg_scores\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gym environment id; DQN.__init__ passes it to gym.make().\n",
    "env_string = 'BreakoutDeterministic-v4'\n",
    "# Uses the constructor defaults batch_size=64, IM_SIZE=84 and m=4.\n",
    "agent = DQN(env_string)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              (None, 20, 20, 32)        8224      \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 9, 9, 64)          32832     \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) (None, 4, 4, 64)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 2, 2, 64)          36928     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 1, 1, 64)          0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 64)                0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 256)               16640     \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 4)                 1028      \n",
      "=================================================================\n",
      "Total params: 95,652\n",
      "Trainable params: 95,652\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the layer-by-layer summary of the network the agent trains and acts with.\n",
    "agent.model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              (None, 20, 20, 32)        8224      \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 9, 9, 64)          32832     \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) (None, 4, 4, 64)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 2, 2, 64)          36928     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 1, 1, 64)          0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 64)                0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 256)               16640     \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 4)                 1028      \n",
      "=================================================================\n",
      "Total params: 95,652\n",
      "Trainable params: 95,652\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the summary of the target network, which DQN.__init__ builds with clone_model(self.model).\n",
    "agent.model_target.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Episode 0] - Average Score: 0.0.\n",
      "[Episode 10] - Average Score: 0.09090909090909091.\n",
      "[Episode 20] - Average Score: 0.14285714285714285.\n"
     ]
    }
   ],
   "source": [
    "# Run the training loop; `scores` receives one running-mean score per finished episode.\n",
    "scores = agent.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# `scores` holds, per episode, the mean score over (up to) the last 100 episodes.\n",
    "# Label the figure so it can be read on its own when the notebook is skimmed.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(scores)\n",
    "ax.set(title='DQN training progress',\n",
    "       xlabel='Episode',\n",
    "       ylabel='Mean score (last 100 episodes)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the environment held by the agent (the Monitor wrapper when MONITOR is True).\n",
    "agent.env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:new_tf]",
   "language": "python",
   "name": "conda-env-new_tf-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
